Import packages

In [1]:
# Import necessary libraries
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.io as pio
import pandas as pd
import plotly.express as px
import plotly, os, joblib
import numpy as np
import pickle

Custom functions for generating spider (radar) plots

In [2]:
# Function to generate the plot
def generatePlot(plotName, rows, cols, data, width, height, vertical_spacing, horizontal_spacing,
                 title_font, marker_size, label_size, tick_size, plot_title, line_color):
    """Build a grid of polar (spider/radar) subplots, one subplot per model.

    Each row of ``data`` becomes one subplot. Every column except ``"model"``
    is drawn as one spoke, with its value multiplied by 100 (presumably the
    scores are fractions in [0, 1], so the spokes read as percentages).
    Subplots are filled left-to-right, top-to-bottom in ``data.index`` order;
    no sorting is done here.

    Parameters
    ----------
    plotName : str
        Not used by this function; kept so existing callers keep working.
    rows, cols : int
        Shape of the subplot grid. It must have at least ``len(data)`` cells.
    data : pandas.DataFrame
        One row per model, indexed by model name (str). It holds numeric score
        columns plus an optional ``"model"`` column, which is skipped.
    width, height : int
        Figure size in pixels.
    vertical_spacing, horizontal_spacing : float
        Spacing between subplots, forwarded to ``make_subplots``.
    title_font : int
        Font size of the per-model subplot titles.
    marker_size : int
        Size of the markers drawn at each spoke.
    label_size : int
        Base font size of the figure.
    tick_size : int
        Font size of the radial-axis tick labels.
    plot_title : str
        Figure-level title.
    line_color : str
        Line/fill colour shared by every trace.

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        If ``data`` has more rows than the ``rows x cols`` grid can hold.
        Without this check, plotly fails later with a less obvious
        row-out-of-range error.
    """
    n_models = len(data.index)
    if n_models > rows * cols:
        raise ValueError("{} models do not fit in a {}x{} subplot grid".format(
            n_models, rows, cols))

    # One polar subplot per grid cell. Build a fresh dict per cell rather than
    # repeating a single shared dict object.
    specs = [[{"type": "polar"} for _ in range(cols)] for _ in range(rows)]

    # '__' in model names is displayed as '-' in the subplot titles.
    fig = make_subplots(rows=rows, cols=cols, vertical_spacing=vertical_spacing,
                        subplot_titles=[name.replace('__', '-') for name in data.index.tolist()],
                        horizontal_spacing=horizontal_spacing, specs=specs)

    # Add one filled Scatterpolar trace per model, in grid order.
    for position, model in enumerate(data.index):
        row, col = divmod(position, cols)  # 0-based grid coordinates
        model_score = data.loc[model]
        metric_names = [metric for metric in model_score.index if metric != "model"]
        metric_values = [model_score.loc[metric] * 100 for metric in metric_names]
        fig.add_trace(go.Scatterpolar(r=metric_values, name=model, dtheta=20,
                                      theta=metric_names, fill='toself',
                                      line_color=line_color),
                      row=row + 1, col=col + 1)  # plotly rows/cols are 1-based

    # Style the subplot titles, which make_subplots creates as annotations.
    for annotation in fig['layout']['annotations']:
        annotation['font'] = dict(size=title_font, family="Arial", color='black')
        annotation['borderpad'] = 5

    # Figure-wide layout and title.
    fig.update_layout(width=width, height=height, font_size=label_size, template="plotly_white",
                      font_family="Arial",
                      showlegend=False, margin=dict(t=30, b=15, r=40, l=40,))
    fig.update_layout(title={'text': plot_title, 'y': 0.995, 'x': 0.5,
                             'xanchor': 'center', 'yanchor': 'top',
                             "font_family": "Arial", "font_size": 10})

    # Use a fixed radial range so all subplots share one scale. Scores below
    # 30 (after the x100 scaling) fall outside the displayed range.
    fig.update_polars(
        radialaxis=dict(visible=True, nticks=7, range=[30, 100],
                        tickfont=dict(size=tick_size)),
        angularaxis=dict(showticklabels=True, ticks='', linewidth=0.2,
                         showline=True, linecolor='black'))

    # Marker styling for every polar trace.
    fig.update_traces(marker=dict(size=marker_size, line_color="black", color=px.colors.sequential.Viridis),
                      selector=dict(type='scatterpolar'))
    return fig
    
    
# Function to generate spider (radar) plots
def getSpyderPlot(results, cvName, plot_title, line_color):
    """Turn a results mapping into a 2x4 grid of spider plots, one per model.

    ``results`` must provide the keys "Acc", "Bal_acc", "F1", "recall",
    "precision", "average_precision", "roc_auc" and "model". The scores are
    renamed to display labels, indexed by model name, and ordered by ascending
    F1 before being handed to ``generatePlot``.

    ``cvName`` is forwarded to ``generatePlot`` as ``plotName``. Returns the
    plotly figure.
    """
    # Display label -> key in ``results``. Insertion order fixes the order of
    # the spokes around each polar plot.
    label_to_key = {
        'Accuracy': "Acc",
        'Balanced Acc': "Bal_acc",
        'F1': "F1",
        'Recall': "recall",
        'Precision': "precision",
        'Avg precision': "average_precision",
        'roc_auc': "roc_auc",
        "model": "model",
    }
    table = pd.DataFrame({label: results[key] for label, key in label_to_key.items()},
                         index=results["model"])

    ranked = table.sort_values('F1')
    return generatePlot(cvName, rows=2, cols=4, data=ranked,
                        width=850, height=450,
                        vertical_spacing=0.08, horizontal_spacing=0.09,
                        title_font=10, marker_size=4, label_size=9, tick_size=9,
                        plot_title=plot_title, line_color=line_color)

Results folder

In [3]:
# Path to the results folder downloaded from GitHub.
# Set the MLCPS_RESULTS_FOLDER environment variable to use a different location.
# The absolute path below is only the original author's local default.
resultsFolder = os.environ.get(
    "MLCPS_RESULTS_FOLDER",
    '/Users/akshay/Desktop/prof_katia_plots/pain-paper/MLcps-paper/MLcps/generateManuPlots/results',
)

1. CLL Dataset

Setup

In [4]:
# Change into the CLL results folder; the pickle below is read relative to it.
# (Nothing in this notebook writes plot files; figures are only displayed.)
os.chdir(os.path.join(resultsFolder, "CLL"))

# Load the pickled results. Use `with` so the file handle is closed.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("results_whole.pickle", 'rb') as fh:
    results_whole = pickle.load(fh)

Figure S1A

In [5]:
getSpyderPlot(
    results_whole,
    cvName="result_whole",
    plot_title="CLL",
    line_color="#B9E4E8",
)

2. Cervical Dataset

Setup

In [6]:
# Change into the cervical results folder; the pickle below is read relative to it.
# (Nothing in this notebook writes plot files; figures are only displayed.)
os.chdir(os.path.join(resultsFolder, "cervical"))

# Load the pickled results. Use `with` so the file handle is closed.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("results_whole.pickle", 'rb') as fh:
    results_whole = pickle.load(fh)

Figure 2A

In [7]:
getSpyderPlot(
    results_whole,
    cvName="result_whole",
    plot_title="Cervical",
    line_color="#B9E4E8",
)

3. TCGA miRNA Dataset

Setup

In [8]:
# Change into the TCGA miRNA results folder; the pickles below are read relative to it.
# (Nothing in this notebook writes plot files; figures are only displayed.)
os.chdir(os.path.join(resultsFolder, "TCGA-BRCA_miRNA"))

# Load both pickled result sets (whole-data and test results, going by the
# file names). Use `with` so each file handle is closed.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("results_whole.pickle", 'rb') as fh:
    results_whole = pickle.load(fh)
with open("results_test.pickle", 'rb') as fh:
    results_test = pickle.load(fh)

Figure 1B

In [9]:
getSpyderPlot(
    results_whole,
    cvName="result_whole",
    plot_title="TCGA-BRCA-miRNA",
    line_color="#B9E4E8",
)

Figure 1C

In [10]:
# Test-set figure: plot `results_test` (loaded in the setup cell above).
# Previously this passed `results_whole`, so Figure 1C duplicated Figure 1B.
getSpyderPlot(
    results_test,
    cvName="result_test",
    plot_title="TCGA-BRCA-miRNA",
    line_color="#8CBFAA",
)

4. TCGA mRNA Dataset

Setup

In [11]:
# Change into the TCGA mRNA results folder; the pickles below are read relative to it.
# (Nothing in this notebook writes plot files; figures are only displayed.)
os.chdir(os.path.join(resultsFolder, "TCGA-BRCA_mRNA"))

# Load both pickled result sets (whole-data and test results, going by the
# file names). Use `with` so each file handle is closed.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("results_whole.pickle", 'rb') as fh:
    results_whole = pickle.load(fh)
with open("results_test.pickle", 'rb') as fh:
    results_test = pickle.load(fh)

Figure 2B

In [12]:
getSpyderPlot(
    results_whole,
    cvName="result_whole",
    plot_title="TCGA-BRCA-mRNA",
    line_color="#B9E4E8",
)

Figure 2D

In [13]:
# Test-set figure: plot `results_test` (loaded in the setup cell above).
# Previously this passed `results_whole`, so Figure 2D duplicated Figure 2B.
getSpyderPlot(
    results_test,
    cvName="result_test",
    plot_title="TCGA-BRCA-mRNA",
    line_color="#8CBFAA",
)